import json
import time
from experiments.baselines.random_search import random_search_baseline
from experiments.baselines.grid_search import grid_search_baseline
from experiments.baselines.tpot_baseline import tpot_baseline
from experiments.baselines.autosklearn_baseline import autosklearn_baseline
from experiments.marl.run_marl import run_marl

# Dataset identifiers that run_benchmarks works through, in this order. Each
# one is passed as the first positional argument to every baseline and to
# run_marl; how an identifier is resolved to actual data is up to those callees.
DATASETS = [
    "iris",
    "adult",
    "covtype",
    "credit-g",
    "bank-marketing",
]


def run_benchmarks(episodes: int = 100, random_iters: int = 50, grid_max: int = 200,
                   tpot_mins: int = 5, askl_time: int = 300, *, datasets=None):
    """Run every baseline plus the MARL run on each dataset and collect the outcomes.

    For each dataset the methods are invoked in a fixed order: random search,
    grid search, TPOT, auto-sklearn, then MARL. A method that raises does not
    abort the benchmark; its slot in the result holds ``{"error": str(exc)}``
    instead of the method's return value.

    Args:
        episodes: Forwarded to ``run_marl`` as ``episodes``.
        random_iters: Forwarded to ``random_search_baseline`` as ``n_iter``.
        grid_max: Forwarded to ``grid_search_baseline`` as ``max_configs``.
        tpot_mins: Forwarded to ``tpot_baseline`` as ``max_time_mins``.
        askl_time: Forwarded to ``autosklearn_baseline`` as
            ``time_left_for_this_task``.
        datasets: Optional iterable of dataset identifiers to benchmark.
            Defaults to the module-level ``DATASETS``. A single string is
            treated as one dataset rather than being iterated per character.

    Returns:
        A dict mapping each dataset identifier to a dict keyed by
        ``"random_search"``, ``"grid_search"``, ``"tpot"``, ``"auto_sklearn"``
        and ``"marl"``. Each value is whatever the corresponding callable
        returned, or ``{"error": <message>}`` if it raised.
    """
    # (result key, callable, method-specific keyword arguments), in run order.
    methods = (
        ("random_search", random_search_baseline, {"n_iter": random_iters}),
        ("grid_search", grid_search_baseline, {"max_configs": grid_max}),
        ("tpot", tpot_baseline, {"max_time_mins": tpot_mins}),
        ("auto_sklearn", autosklearn_baseline, {"time_left_for_this_task": askl_time}),
        ("marl", run_marl, {"episodes": episodes}),
    )

    if datasets is None:
        datasets = DATASETS
    elif isinstance(datasets, str):
        # A bare string would otherwise be benchmarked one character at a time.
        datasets = [datasets]

    results = {}
    for ds in datasets:
        print(f"\n=== Dataset: {ds} ===")
        per_dataset = results[ds] = {}
        for key, method, kwargs in methods:
            # Deliberately broad: one failing method (missing optional
            # dependency, bad dataset, ...) must not discard the results
            # already gathered for the other methods and datasets.
            try:
                per_dataset[key] = method(ds, **kwargs)
            except Exception as e:
                per_dataset[key] = {"error": str(e)}
    return results


if __name__ == "__main__":
    import argparse

    # Every command-line option is an integer flag; the defaults match the
    # defaults in run_benchmarks' signature.
    parser = argparse.ArgumentParser()
    for flag, fallback in (
        ("--episodes", 100),
        ("--random_iters", 50),
        ("--grid_max", 200),
        ("--tpot_mins", 5),
        ("--askl_time", 300),
    ):
        parser.add_argument(flag, type=int, default=fallback)
    cli = parser.parse_args()

    outcome = run_benchmarks(
        episodes=cli.episodes,
        random_iters=cli.random_iters,
        grid_max=cli.grid_max,
        tpot_mins=cli.tpot_mins,
        askl_time=cli.askl_time,
    )
    # Dump the full nested result mapping as pretty-printed JSON on stdout.
    print(json.dumps(outcome, indent=2))
